{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bb460e3d",
   "metadata": {},
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch import nn\n",
    "from torch.utils import data\n",
    "from torchvision import transforms\n",
    "from d2l import torch as d2l\n",
    "\n",
    "\n",
    "def get_dataloader_workers():\n",
    "    \"\"\"Return the number of worker processes used to read the data (4).\"\"\"\n",
    "    worker_count = 4\n",
    "    return worker_count\n",
    "\n",
    "\n",
    "def load_data_fashion_mnist(batch_size, resize=None):\n",
    "    \"\"\"Download Fashion-MNIST if needed and return its train/test data loaders.\n",
    "\n",
    "    Args:\n",
    "        batch_size: Number of samples per mini-batch for both loaders.\n",
    "        resize: Optional size passed to ``transforms.Resize``; when falsy the\n",
    "            images are not resized.\n",
    "\n",
    "    Returns:\n",
    "        A ``(train_loader, test_loader)`` tuple of ``data.DataLoader`` objects.\n",
    "        Only the training loader shuffles its samples.\n",
    "    \"\"\"\n",
    "    pipeline = [transforms.ToTensor()]\n",
    "    # For large images an optional Resize step is placed ahead of ToTensor.\n",
    "    if resize:\n",
    "        pipeline.insert(0, transforms.Resize(resize))\n",
    "    trans = transforms.Compose(pipeline)\n",
    "    # download=True only fetches the files when they are not already under ./data,\n",
    "    # so the notebook also runs on a machine that has no local copy yet\n",
    "    # (download=False raised an error in that situation).\n",
    "    mnist_train = torchvision.datasets.FashionMNIST(\n",
    "        root=\"./data\", train=True, transform=trans, download=True)\n",
    "    mnist_test = torchvision.datasets.FashionMNIST(\n",
    "        root=\"./data\", train=False, transform=trans, download=True)\n",
    "    num_workers = get_dataloader_workers()\n",
    "    return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=num_workers),\n",
    "            data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=num_workers))\n",
    "\n",
    "\n",
    "# 1. Load the data\n",
    "batch_size = 256\n",
    "train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)\n",
    "\n",
    "# 2. Initialise the model parameters\n",
    "# PyTorch does not implicitly reshape its inputs, so a flatten layer is placed\n",
    "# in front of the linear layer to adjust the shape of the network input.\n",
    "# Each 28*28 image is flattened to 784 features; the output has 10 classes.\n",
    "net = nn.Sequential(\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(in_features=784, out_features=10),\n",
    ")\n",
    "\n",
    "\n",
    "def init_weights(m):\n",
    "    \"\"\"Re-initialise the weight of an ``nn.Linear`` module with ``normal_(std=0.01)``.\n",
    "\n",
    "    Modules of any other type are left untouched; the bias is not modified.\n",
    "    \"\"\"\n",
    "    if type(m) != nn.Linear:\n",
    "        return\n",
    "    nn.init.normal_(m.weight, std=0.01)\n",
    "\n",
    "\n",
    "net.apply(init_weights)\n",
    "\n",
    "# 3. Loss\n",
    "loss = nn.CrossEntropyLoss(reduction=\"none\")\n",
    "\n",
    "# 4. Optimisation algorithm\n",
    "trainer = torch.optim.SGD(params=net.parameters(), lr=0.1)\n",
    "\n",
    "# 5. Training\n",
    "num_epochs = 10\n",
    "# Displays a live-updating plot while training\n",
    "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)\n"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab95c0e",
   "metadata": {},
   "source": [],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:anaconda3-pytorch] *",
   "language": "python",
   "name": "conda-env-anaconda3-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
